package io.gitee.sections.sequence.core.dao;

import io.gitee.sections.sequence.core.SequenceDefinition;
import io.gitee.sections.sequence.core.exp.SequenceException;

import java.sql.*;
import java.util.Objects;
import java.util.Optional;

/**
 * Base {@link DataAccessor} that keeps sequence definitions and their current counter in the
 * {@code SEQ_SEQUENCE} table, using plain JDBC.
 *
 * <p>Each thread lazily obtains its own {@link Connection} from {@link #getConnection()} and keeps
 * it in {@link #threadLocalConnection} until {@link #close()} is called on that same thread.
 * Subclasses only need to supply the connection.
 */
public abstract class AbstractDatabaseAccessor implements DataAccessor {
    /** Connection bound to the calling thread; created on first use and removed by {@link #close()}. */
    protected ThreadLocal<Connection> threadLocalConnection = new ThreadLocal<>();

    /** SQLState class that JDBC/SQL use for every integrity-constraint violation (e.g. duplicate key). */
    private static final String INTEGRITY_CONSTRAINT_SQL_STATE_CLASS = "23";

    /**
     * Opens a new JDBC connection. Invoked at most once per thread until that thread calls
     * {@link #close()}.
     *
     * @return a fresh connection
     * @throws ClassNotFoundException if the JDBC driver class cannot be loaded
     * @throws SQLException if the connection cannot be opened
     */
    protected abstract Connection getConnection() throws ClassNotFoundException, SQLException;

    /**
     * Closes the calling thread's connection, if one was opened, and unbinds it from the thread.
     *
     * @throws SequenceException if the driver fails to close the connection
     */
    @Override
    public void close() {
        Connection connection = threadLocalConnection.get();
        if (Objects.isNull(connection)) {
            return;
        }
        // Unbind first so a failing close() never leaves a broken connection attached to the thread.
        threadLocalConnection.remove();
        try {
            connection.close();
        } catch (SQLException e) {
            throw new SequenceException("fail to close connection", e);
        }
    }

    /**
     * Loads the definition stored for {@code key}.
     *
     * @param key sequence key
     * @return the definition, or {@link Optional#empty()} when no row has this key
     * @throws SequenceException on any database error
     */
    @Override
    public Optional<SequenceDefinition> find(String key) {
        String sql = "SELECT `INITIAL`,`STEP_SIZE`,`CACHE_SIZE`, `CACHE_MODE` FROM `SEQ_SEQUENCE` WHERE `KEY` = ?";
        try (PreparedStatement statement = requireConnection().prepareStatement(sql)) {
            statement.setString(1, key);
            try (ResultSet resultSet = statement.executeQuery()) {
                if (!resultSet.next()) {
                    return Optional.empty();
                }
                long initial = resultSet.getLong(1);
                int stepSize = resultSet.getInt(2);
                int cacheSize = resultSet.getInt(3);
                String cacheMode = resultSet.getString(4);
                return Optional.of(new SequenceDefinition(key, initial, stepSize, cacheSize, cacheMode));
            }
        } catch (Exception e) {
            final String msg = String.format("fail to find definition of sequence %s", key);
            throw new SequenceException(msg, e);
        }
    }

    /**
     * Inserts a new sequence row whose counter ({@code NUMBER}) starts at the definition's initial
     * value. If the row already exists (integrity-constraint violation), this is a no-op.
     *
     * @param definition sequence to persist
     * @throws SequenceException on any other database error
     */
    @Override
    public void insert(SequenceDefinition definition) {
        String sql = "INSERT INTO `SEQ_SEQUENCE`(`KEY`,`INITIAL`,`STEP_SIZE`,`CACHE_SIZE`, `CACHE_MODE`, `NUMBER`) VALUES (?, ?, ?, ?, ?, ?)";
        try (PreparedStatement statement = requireConnection().prepareStatement(sql)) {
            statement.setString(1, definition.getKey());
            statement.setLong(2, definition.getInitial());
            statement.setInt(3, definition.getStepSize());
            statement.setInt(4, definition.getCacheSize());
            statement.setString(5, definition.getCacheMode());
            // The counter column starts at the initial value.
            statement.setLong(6, definition.getInitial());
            statement.execute();
        } catch (SQLException e) {
            if (!isIntegrityConstraintViolation(e)) {
                final String msg = String.format("fail to insert sequence %s", definition.toString());
                throw new SequenceException(msg, e);
            }
            // Deliberately ignored: the sequence already exists.
        } catch (Exception e) {
            final String msg = String.format("fail to insert sequence %s", definition.toString());
            throw new SequenceException(msg, e);
        }
    }

    /**
     * Overwrites the stored configuration (initial value, step, cache size and mode) of the
     * sequence with the definition's key. The current counter ({@code NUMBER}) is left untouched.
     *
     * @param definition new configuration
     * @throws SequenceException on any database error
     */
    @Override
    public void update(SequenceDefinition definition) {
        String sql = "UPDATE `SEQ_SEQUENCE` SET `INITIAL` = ?, `STEP_SIZE` = ? , `CACHE_SIZE` = ?, `CACHE_MODE` = ? WHERE `KEY` = ?";
        try (PreparedStatement statement = requireConnection().prepareStatement(sql)) {
            statement.setLong(1, definition.getInitial());
            statement.setInt(2, definition.getStepSize());
            statement.setInt(3, definition.getCacheSize());
            statement.setString(4, definition.getCacheMode());
            statement.setString(5, definition.getKey());
            statement.execute();
        } catch (Exception e) {
            final String msg = String.format("fail to update sequence %s", definition.toString());
            throw new SequenceException(msg, e);
        }
    }

    /**
     * Advances the stored counter of a sequence by one cache batch
     * ({@code stepSize * cacheSize}) and returns the new value.
     *
     * <p>The read ({@code SELECT ... FOR UPDATE}) and the write run in a single transaction on the
     * calling thread's connection. On failure the transaction is rolled back. In every case the
     * connection's previous auto-commit mode is restored, so later calls such as
     * {@link #insert} are not left running inside an uncommitted transaction.
     *
     * @param definition sequence to advance
     * @return the counter value after the increment
     * @throws SequenceException if the sequence row is missing, was changed concurrently, or any
     *     database error occurs
     */
    @Override
    public long grow(SequenceDefinition definition) {
        try {
            Connection connection = requireConnection();
            boolean previousAutoCommit = connection.getAutoCommit();
            connection.setAutoCommit(false);
            try {
                long oldNumber = getNumberForUpdate(definition.getKey());
                // Widen before multiplying: int * int may overflow before being added to the long.
                long increment = (long) definition.getStepSize() * definition.getCacheSize();
                long newNumber = oldNumber + increment;
                updateNumber(definition.getKey(), oldNumber, newNumber);
                connection.commit();
                return newNumber;
            } catch (Exception e) {
                rollbackQuietly(connection, e);
                throw e;
            } finally {
                connection.setAutoCommit(previousAutoCommit);
            }
        } catch (Exception e) {
            final String msg = String.format("fail to load number of sequence %s", definition.toString());
            throw new SequenceException(msg, e);
        }
    }

    /**
     * Reads the current counter of {@code key}, locking the row with {@code FOR UPDATE}.
     *
     * @throws SequenceException if no row has this key or on any database error
     */
    private long getNumberForUpdate(String key) {
        String sql = "SELECT `NUMBER` FROM `SEQ_SEQUENCE` WHERE `KEY` = ? FOR UPDATE";
        final String msg = String.format("fail to get current number of sequence %s", key);
        try (PreparedStatement statement = requireConnection().prepareStatement(sql)) {
            statement.setString(1, key);
            try (ResultSet resultSet = statement.executeQuery()) {
                if (resultSet.next()) {
                    return resultSet.getLong(1);
                }
            }
        } catch (Exception e) {
            throw new SequenceException(msg, e);
        }
        // Thrown outside the try so it is not re-wrapped by the catch above.
        throw new SequenceException(msg);
    }

    /**
     * Sets the counter of {@code key} to {@code newNumber}, but only if it still equals
     * {@code oldNumber}.
     *
     * @throws SequenceException if no row matched (counter changed or row removed) or on any
     *     database error
     */
    private void updateNumber(String key, long oldNumber, long newNumber) {
        String sql = "UPDATE `SEQ_SEQUENCE` SET `NUMBER` = ? WHERE `NUMBER` =? AND `KEY` = ?";
        int updatedRows;
        try (PreparedStatement statement = requireConnection().prepareStatement(sql)) {
            statement.setLong(1, newNumber);
            statement.setLong(2, oldNumber);
            statement.setString(3, key);
            updatedRows = statement.executeUpdate();
        } catch (Exception e) {
            final String msg = String.format("fail to update sequence %s", key);
            throw new SequenceException(msg, e);
        }
        if (updatedRows == 0) {
            // The counter no longer equals what was read. Returning newNumber anyway could hand out
            // a range that another caller already received.
            final String msg = String.format("fail to update number of sequence %s: it was changed or removed concurrently", key);
            throw new SequenceException(msg);
        }
    }

    /** Returns the calling thread's connection, opening and binding a new one on first use. */
    private Connection requireConnection() throws SQLException, ClassNotFoundException {
        Connection connection = threadLocalConnection.get();
        if (Objects.isNull(connection)) {
            connection = getConnection();
            threadLocalConnection.set(connection);
        }
        return connection;
    }

    /**
     * Rolls back the current transaction. Any rollback failure is attached to {@code cause} as a
     * suppressed exception instead of replacing it.
     */
    private static void rollbackQuietly(Connection connection, Exception cause) {
        try {
            connection.rollback();
        } catch (SQLException rollbackFailure) {
            cause.addSuppressed(rollbackFailure);
        }
    }

    /**
     * Returns {@code true} if {@code e} reports an integrity-constraint violation. This covers
     * both the dedicated JDBC subclass and drivers that only set SQLState class {@code 23}.
     */
    private static boolean isIntegrityConstraintViolation(SQLException e) {
        if (e instanceof SQLIntegrityConstraintViolationException) {
            return true;
        }
        String sqlState = e.getSQLState();
        return sqlState != null && sqlState.startsWith(INTEGRITY_CONSTRAINT_SQL_STATE_CLASS);
    }
}
